# Author: WY
# Date: 2023/5/17 16:27

import numpy as np
import model.GCN_TCN_Model as Model
import model.GCN_TCN_Model_Multi as Model_m
import model.DataLoader as loader
import pandas as pd

if __name__ == '__main__':
    # 读取文件列表
    dataList = pd.read_json('./data/DataList.json')
    columns = ["大气温度（摄氏度）", "水平大气压（毫米汞柱）", "海平面大气压（毫米汞柱）", "气压变化趋势", "相对湿度", "风向", "平均风速（m/s）",
               "过去12小时内最低气温", "过去12小时内最高气温", "水平能见度（km）", "露点温度（摄氏度）", "降水量（毫米）", "到达规定降水量的时间"]
    # 读取样本文件并整合到一个list中
    print("载入数据集")
    list = [pd.read_csv('./data/file/' + name)[columns] for name in dataList["name"]]

    # 将数据载入到DataLoader中
    # dataloader_train, dataloader_test = loader.get_loader(split=Model.split, batch_size=Model.batch_size + Model.pre_step, data=list)

    # Model.train(dataloader_train)
    # Model.bias_test(dataloader_train)
    # Model.variance_test(dataloader_test)

    # dataloader_train, dataloader_test = loader.get_loader(split=Model_m.split, batch_size=Model_m.batch_size + 8, data=list)
    # Model_m.test(dataloader_train, 8)

    for step in range(10, 11):
        dataloader_train, dataloader_test = loader.get_loader(split=Model_m.split, batch_size=Model_m.batch_size + step, data=list)
        Model_m.performance_record(step, dataloader_train, dataloader_test)


